{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "import cPickle\n",
    "import numpy as np\n",
    "\n",
    "# Directory holding the CIFAR-10 python batch files (resolved against the current working directory).\n",
    "CIFAR_DIR = \"./../../cifar-10-batches-py\"\n",
    "# List the directory up front so that a wrong path fails here instead of during loading.\n",
    "print(os.listdir(CIFAR_DIR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(filename):\n",
    "    \"\"\"read data from data file.\"\"\"\n",
    "    with open(filename, 'rb') as f:\n",
    "        data = cPickle.load(f)\n",
    "        return data['data'], data['labels']\n",
    "\n",
    "# Hand-rolled, in-memory batch provider (fills the role a tf.data.Dataset pipeline would).\n",
    "class CifarData:\n",
    "    def __init__(self, filenames, need_shuffle):\n",
    "        all_data = []\n",
    "        all_labels = []\n",
    "        for filename in filenames:\n",
    "            data, labels = load_data(filename)\n",
    "            all_data.append(data)\n",
    "            all_labels.append(labels)\n",
    "        self._data = np.vstack(all_data)\n",
    "        self._data = self._data / 127.5 - 1\n",
    "        self._labels = np.hstack(all_labels)\n",
    "        print self._data.shape\n",
    "        print self._labels.shape\n",
    "        \n",
    "        self._num_examples = self._data.shape[0]\n",
    "        self._need_shuffle = need_shuffle\n",
    "        self._indicator = 0\n",
    "        if self._need_shuffle:\n",
    "            self._shuffle_data()\n",
    "            \n",
    "    def _shuffle_data(self):\n",
    "        # [0,1,2,3,4,5] -> [5,3,2,4,0,1]\n",
    "        p = np.random.permutation(self._num_examples)\n",
    "        self._data = self._data[p]\n",
    "        self._labels = self._labels[p]\n",
    "    \n",
    "    def next_batch(self, batch_size):\n",
    "        \"\"\"return batch_size examples as a batch.\"\"\"\n",
    "        end_indicator = self._indicator + batch_size\n",
    "        if end_indicator > self._num_examples:\n",
    "            if self._need_shuffle:\n",
    "                self._shuffle_data()\n",
    "                self._indicator = 0\n",
    "                end_indicator = batch_size\n",
    "            else:\n",
    "                raise Exception(\"have no more examples\")\n",
    "        if end_indicator > self._num_examples:\n",
    "            raise Exception(\"batch size is larger than all examples\")\n",
    "        batch_data = self._data[self._indicator: end_indicator]\n",
    "        batch_labels = self._labels[self._indicator: end_indicator]\n",
    "        self._indicator = end_indicator\n",
    "        return batch_data, batch_labels\n",
    "\n",
    "# Training shards are data_batch_1 .. data_batch_5; the held-out examples live in test_batch.\n",
    "train_filenames = [\n",
    "    os.path.join(CIFAR_DIR, 'data_batch_%d' % i)\n",
    "    for i in range(1, 6)\n",
    "]\n",
    "test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]\n",
    "\n",
    "# Training data is shuffled and can be iterated indefinitely; test data is read once, in order.\n",
    "train_data = CifarData(train_filenames, need_shuffle=True)\n",
    "test_data = CifarData(test_filenames, need_shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input features: one flat row of 3072 float values per example\n",
    "# (presumably 32*32*3 flattened CIFAR pixels - confirm).\n",
    "x = tf.placeholder(tf.float32, [None, 3072])\n",
    "# Integer class ids, one per example; shape [None], e.g. [0,5,6,3].\n",
    "y = tf.placeholder(tf.int64, [None])\n",
    "# Three fully connected ReLU layers with 100, 100 and 50 units.\n",
    "hidden1 = tf.layers.dense(x, 100, activation=tf.nn.relu)\n",
    "hidden2 = tf.layers.dense(hidden1, 100, activation=tf.nn.relu)\n",
    "hidden3 = tf.layers.dense(hidden2, 50, activation=tf.nn.relu)\n",
    "# Output layer: 10 raw scores (logits) per example, no activation applied.\n",
    "y_ = tf.layers.dense(hidden3, 10)\n",
    "\n",
    "# The sparse loss is given the integer labels and the raw logits directly.\n",
    "loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)\n",
    "# Conceptually:\n",
    "#   y_ -> softmax\n",
    "#   y  -> one_hot\n",
    "#   loss = -sum(one_hot(y) * log(softmax(y_)))\n",
    "\n",
    "# Predicted class = index of the largest logit along axis 1.\n",
    "predict = tf.argmax(y_, 1)\n",
    "# Per-example match against the labels (as 1/0: [1,0,1,1,1,0,0,0]).\n",
    "correct_prediction = tf.equal(predict, y)\n",
    "# Accuracy = mean of the matches over the fed batch.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))\n",
    "\n",
    "# Running train_op performs one Adam update (learning rate 1e-3) that minimises loss.\n",
    "with tf.name_scope('train_op'):\n",
    "    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init = tf.global_variables_initializer()\n",
    "batch_size = 20\n",
    "train_steps = 10000\n",
    "test_steps = 100\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    for i in range(train_steps):\n",
    "        step = i + 1\n",
    "        batch_data, batch_labels = train_data.next_batch(batch_size)\n",
    "        train_feed = {x: batch_data, y: batch_labels}\n",
    "        # Fetching train_op together with the metrics performs one optimisation step.\n",
    "        loss_val, acc_val, _ = sess.run([loss, accuracy, train_op], feed_dict=train_feed)\n",
    "        # Report the metrics of the current training batch every 500 steps.\n",
    "        if step % 500 == 0:\n",
    "            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' % (step, loss_val, acc_val))\n",
    "        # Evaluation only happens every 5000 steps.\n",
    "        if step % 5000 != 0:\n",
    "            continue\n",
    "        # A non-shuffling CifarData raises once exhausted, so a fresh one is built per evaluation.\n",
    "        test_data = CifarData(test_filenames, False)\n",
    "        all_test_acc_val = []\n",
    "        # Accuracy is averaged over test_steps batches (100 * 20 = 2000 examples).\n",
    "        for j in range(test_steps):\n",
    "            test_batch_data, test_batch_labels = test_data.next_batch(batch_size)\n",
    "            test_acc_val = sess.run(\n",
    "                [accuracy],\n",
    "                feed_dict={x: test_batch_data, y: test_batch_labels})\n",
    "            all_test_acc_val.append(test_acc_val)\n",
    "        test_acc = np.mean(all_test_acc_val)\n",
    "        print('[Test ] Step: %d, acc: %4.5f' % (step, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
